Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][x86vector] AVX512-BF16 Convert packed F32 to BF16 #125685

Merged
merged 1 commit into from
Feb 18, 2025

Conversation

adam-smnk
Copy link
Contributor

Adds AVX512 bf16 conversion from packed f32 to bf16 elements.

Tests are slightly refactored to better follow the file's conventions.

Adds AVX512 bf16 conversion from packed f32 to bf16 elements.

Tests are slightly refactored to better follow file's convention.
@llvmbot
Copy link
Member

llvmbot commented Feb 4, 2025

@llvm/pr-subscribers-mlir-vector

@llvm/pr-subscribers-mlir

Author: Adam Siemieniuk (adam-smnk)

Changes

Adds AVX512 bf16 conversion from packed f32 to bf16 elements.

Tests are slightly refactored to better follow file's convention.


Full diff: https://github.com/llvm/llvm-project/pull/125685.diff

6 Files Affected:

  • (modified) mlir/include/mlir/Dialect/X86Vector/X86Vector.td (+40)
  • (modified) mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp (+40-2)
  • (added) mlir/test/Dialect/X86Vector/cvt-packed-f32-to-bf16.mlir (+24)
  • (modified) mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir (+18)
  • (modified) mlir/test/Dialect/X86Vector/roundtrip.mlir (+20)
  • (modified) mlir/test/Target/LLVMIR/x86vector.mlir (+29-9)
diff --git a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
index 16181d7e760db5..566013e73f4b89 100644
--- a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
+++ b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
@@ -341,6 +341,46 @@ def DotBF16Ps512IntrOp : AVX512_IntrOp<"dpbf16ps.512", 1, [Pure,
   let results = (outs VectorOfLengthAndType<[16], [F32]>:$res);
 }
 
+//----------------------------------------------------------------------------//
+// Convert packed F32 to packed BF16
+//----------------------------------------------------------------------------//
+
+def CvtPackedF32ToBF16Op : AVX512_Op<"cvt.packed.f32_to_bf16", [Pure,
+  AllElementCountsMatch<["a", "dst"]>]> {
+  let summary = "Convert packed F32 to packed BF16 Data.";
+  let description = [{
+    The `convert_f32_to_bf16` op is an AVX512-BF16 specific op that can lower
+    to the proper LLVMAVX512BF16 operation `llvm.cvtneps2bf16` depending on
+    the width of MLIR vectors it is applied to.
+
+    #### From the Intel Intrinsics Guide:
+
+    Convert packed single-precision (32-bit) floating-point elements in `a` to
+    packed BF16 (16-bit) floating-point elements, and store the results in `dst`.
+
+    Example:
+    ```mlir
+    %dst = x86vector.avx512.cvt.packed.f32_to_bf16 %a : vector<8xf32> -> vector<8xbf16>
+    ```
+  }];
+  let arguments = (ins VectorOfLengthAndType<[8, 16], [F32]>:$a);
+  let results = (outs VectorOfLengthAndType<[8, 16], [BF16]>:$dst);
+  let assemblyFormat =
+    "$a attr-dict `:` type($a) `->` type($dst)";
+}
+
+def CvtNeF32ToBF16Ps256IntrOp : AVX512_IntrOp<"cvtneps2bf16.256", 1, [Pure],
+    /*extension=*/"bf16"> {
+  let arguments = (ins VectorOfLengthAndType<[8], [F32]>:$a);
+  let results = (outs VectorOfLengthAndType<[8], [BF16]>:$res);
+}
+
+def CvtNeF32ToBF16Ps512IntrOp : AVX512_IntrOp<"cvtneps2bf16.512", 1, [Pure],
+    /*extension=*/"bf16"> {
+  let arguments = (ins VectorOfLengthAndType<[16], [F32]>:$a);
+  let results = (outs VectorOfLengthAndType<[16], [BF16]>:$res);
+}
+
 //===----------------------------------------------------------------------===//
 // AVX op definitions
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp
index 260ac9ce589a38..f1fbb39b97fc49 100644
--- a/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp
@@ -131,6 +131,39 @@ struct DotBF16OpConversion : public ConvertOpToLLVMPattern<DotBF16Op> {
   }
 };
 
+struct CvtPackedF32ToBF16Conversion
+    : public ConvertOpToLLVMPattern<CvtPackedF32ToBF16Op> {
+  using ConvertOpToLLVMPattern<CvtPackedF32ToBF16Op>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(CvtPackedF32ToBF16Op op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto typeA = dyn_cast<VectorType>(op.getA().getType());
+    unsigned elemBitWidth = typeA.getElementTypeBitWidth();
+    unsigned opBitWidth = typeA.getShape()[0] * elemBitWidth;
+
+    auto opType = op.getDst().getType();
+    auto opA = op.getA();
+
+    switch (opBitWidth) {
+    case 256: {
+      rewriter.replaceOpWithNewOp<CvtNeF32ToBF16Ps256IntrOp>(op, opType, opA);
+      break;
+    }
+    case 512: {
+      rewriter.replaceOpWithNewOp<CvtNeF32ToBF16Ps512IntrOp>(op, opType, opA);
+      break;
+    }
+    default: {
+      return rewriter.notifyMatchFailure(
+          op, "unsupported AVX512-BF16 packed f32 to bf16 variant");
+    }
+    }
+
+    return success();
+  }
+};
+
 struct RsqrtOpConversion : public ConvertOpToLLVMPattern<RsqrtOp> {
   using ConvertOpToLLVMPattern<RsqrtOp>::ConvertOpToLLVMPattern;
 
@@ -202,8 +235,10 @@ using Registry = RegistryImpl<
 void mlir::populateX86VectorLegalizeForLLVMExportPatterns(
     const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
   Registry::registerPatterns(converter, patterns);
-  patterns.add<MaskCompressOpConversion, DotBF16OpConversion, RsqrtOpConversion,
-               DotOpConversion>(converter);
+  patterns
+      .add<MaskCompressOpConversion, DotBF16OpConversion,
+           CvtPackedF32ToBF16Conversion, RsqrtOpConversion, DotOpConversion>(
+          converter);
 }
 
 void mlir::configureX86VectorLegalizeForExportTarget(
@@ -215,6 +250,9 @@ void mlir::configureX86VectorLegalizeForExportTarget(
   target.addLegalOp<DotBF16Ps256IntrOp>();
   target.addLegalOp<DotBF16Ps512IntrOp>();
   target.addIllegalOp<DotBF16Op>();
+  target.addLegalOp<CvtNeF32ToBF16Ps256IntrOp>();
+  target.addLegalOp<CvtNeF32ToBF16Ps512IntrOp>();
+  target.addIllegalOp<CvtPackedF32ToBF16Op>();
   target.addLegalOp<RsqrtIntrOp>();
   target.addIllegalOp<RsqrtOp>();
   target.addLegalOp<DotIntrOp>();
diff --git a/mlir/test/Dialect/X86Vector/cvt-packed-f32-to-bf16.mlir b/mlir/test/Dialect/X86Vector/cvt-packed-f32-to-bf16.mlir
new file mode 100644
index 00000000000000..c97c52f01c3b03
--- /dev/null
+++ b/mlir/test/Dialect/X86Vector/cvt-packed-f32-to-bf16.mlir
@@ -0,0 +1,24 @@
+// REQUIRES: target=x86{{.*}}
+
+// RUN: mlir-opt %s \
+// RUN:   -convert-vector-to-llvm="enable-x86vector" -convert-to-llvm \
+// RUN:   -reconcile-unrealized-casts | \
+// RUN: mlir-translate --mlir-to-llvmir | \
+// RUN: llc -mcpu=sapphirerapids | \
+// RUN: FileCheck %s
+
+func.func @avx512bf16_cvt_packed_f32_to_bf16_256(
+    %a: vector<8xf32>) -> vector<8xbf16> {
+  %0 = x86vector.avx512.cvt.packed.f32_to_bf16 %a : vector<8xf32> -> vector<8xbf16>
+  return %0 : vector<8xbf16>
+}
+// CHECK-LABEL: avx512bf16_cvt_packed_f32_to_bf16_256:
+// CHECK: vcvtneps2bf16{{.*}}%xmm
+
+func.func @avx512bf16_cvt_packed_f32_to_bf16_512(
+    %a: vector<16xf32>) -> vector<16xbf16> {
+  %0 = x86vector.avx512.cvt.packed.f32_to_bf16 %a : vector<16xf32> -> vector<16xbf16>
+  return %0 : vector<16xbf16>
+}
+// CHECK-LABEL: avx512bf16_cvt_packed_f32_to_bf16_512:
+// CHECK: vcvtneps2bf16{{.*}}%ymm
diff --git a/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir b/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir
index ed9177eaec9ce4..59be7dd75b3b0b 100644
--- a/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir
+++ b/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir
@@ -70,6 +70,24 @@ func.func @avx512bf16_dot_512(%src: vector<16xf32>, %a: vector<32xbf16>,
   return %0 : vector<16xf32>
 }
 
+// CHECK-LABEL: func @avx512bf16_cvt_packed_f32_to_bf16_256
+func.func @avx512bf16_cvt_packed_f32_to_bf16_256(
+  %a: vector<8xf32>) -> (vector<8xbf16>)
+{
+  // CHECK: x86vector.avx512.intr.cvtneps2bf16.256
+  %0 = x86vector.avx512.cvt.packed.f32_to_bf16 %a : vector<8xf32> -> vector<8xbf16>
+  return %0 : vector<8xbf16>
+}
+
+// CHECK-LABEL: func @avx512bf16_cvt_packed_f32_to_bf16_512
+func.func @avx512bf16_cvt_packed_f32_to_bf16_512(
+  %a: vector<16xf32>) -> (vector<16xbf16>)
+{
+  // CHECK: x86vector.avx512.intr.cvtneps2bf16.512
+  %0 = x86vector.avx512.cvt.packed.f32_to_bf16 %a : vector<16xf32> -> vector<16xbf16>
+  return %0 : vector<16xbf16>
+}
+
 // CHECK-LABEL: func @avx_rsqrt
 func.func @avx_rsqrt(%a: vector<8xf32>) -> (vector<8xf32>)
 {
diff --git a/mlir/test/Dialect/X86Vector/roundtrip.mlir b/mlir/test/Dialect/X86Vector/roundtrip.mlir
index cf74a7ee602558..0d00448c63da88 100644
--- a/mlir/test/Dialect/X86Vector/roundtrip.mlir
+++ b/mlir/test/Dialect/X86Vector/roundtrip.mlir
@@ -74,6 +74,26 @@ func.func @avx512bf16_dot_512(%src: vector<16xf32>, %a: vector<32xbf16>,
   return %0 : vector<16xf32>
 }
 
+// CHECK-LABEL: func @avx512bf16_cvt_packed_f32_to_bf16_256
+func.func @avx512bf16_cvt_packed_f32_to_bf16_256(
+  %a: vector<8xf32>) -> (vector<8xbf16>)
+{
+  // CHECK: x86vector.avx512.cvt.packed.f32_to_bf16 {{.*}} :
+  // CHECK-SAME: vector<8xf32> -> vector<8xbf16>
+  %0 = x86vector.avx512.cvt.packed.f32_to_bf16 %a : vector<8xf32> -> vector<8xbf16>
+  return %0 : vector<8xbf16>
+}
+
+// CHECK-LABEL: func @avx512bf16_cvt_packed_f32_to_bf16_512
+func.func @avx512bf16_cvt_packed_f32_to_bf16_512(
+  %a: vector<16xf32>) -> (vector<16xbf16>)
+{
+  // CHECK: x86vector.avx512.cvt.packed.f32_to_bf16 {{.*}} :
+  // CHECK-SAME: vector<16xf32> -> vector<16xbf16>
+  %0 = x86vector.avx512.cvt.packed.f32_to_bf16 %a : vector<16xf32> -> vector<16xbf16>
+  return %0 : vector<16xbf16>
+}
+
 // CHECK-LABEL: func @avx_rsqrt
 func.func @avx_rsqrt(%a: vector<8xf32>) -> (vector<8xf32>)
 {
diff --git a/mlir/test/Target/LLVMIR/x86vector.mlir b/mlir/test/Target/LLVMIR/x86vector.mlir
index 1df03f10c93214..db1c10cd5cd37a 100644
--- a/mlir/test/Target/LLVMIR/x86vector.mlir
+++ b/mlir/test/Target/LLVMIR/x86vector.mlir
@@ -62,37 +62,57 @@ llvm.func @LLVM_x86_vp2intersect_q_512(%a: vector<8xi64>, %b: vector<8xi64>)
 
 // CHECK-LABEL: define <4 x float> @LLVM_x86_avx512bf16_dpbf16ps_128
 llvm.func @LLVM_x86_avx512bf16_dpbf16ps_128(
-    %arg0: vector<4xf32>, %arg1: vector<8xbf16>, %arg2: vector<8xbf16>
+    %src: vector<4xf32>, %a: vector<8xbf16>, %b: vector<8xbf16>
   ) -> vector<4xf32>
 {
   // CHECK: call <4 x float> @llvm.x86.avx512bf16.dpbf16ps.128(
-  %0 = "x86vector.avx512.intr.dpbf16ps.128"(%arg0, %arg1, %arg2)
+  %0 = "x86vector.avx512.intr.dpbf16ps.128"(%src, %a, %b)
     : (vector<4xf32>, vector<8xbf16>, vector<8xbf16>) -> vector<4xf32>
   llvm.return %0 : vector<4xf32>
 }
 
 // CHECK-LABEL: define <8 x float> @LLVM_x86_avx512bf16_dpbf16ps_256
 llvm.func @LLVM_x86_avx512bf16_dpbf16ps_256(
-    %arg0: vector<8xf32>, %arg1: vector<16xbf16>, %arg2: vector<16xbf16>
+    %src: vector<8xf32>, %a: vector<16xbf16>, %b: vector<16xbf16>
   ) -> vector<8xf32>
 {
   // CHECK: call <8 x float> @llvm.x86.avx512bf16.dpbf16ps.256(
-  %0 = "x86vector.avx512.intr.dpbf16ps.256"(%arg0, %arg1, %arg2)
+  %0 = "x86vector.avx512.intr.dpbf16ps.256"(%src, %a, %b)
     : (vector<8xf32>, vector<16xbf16>, vector<16xbf16>) -> vector<8xf32>
   llvm.return %0 : vector<8xf32>
 }
 
 // CHECK-LABEL: define <16 x float> @LLVM_x86_avx512bf16_dpbf16ps_512
 llvm.func @LLVM_x86_avx512bf16_dpbf16ps_512(
-    %arg0: vector<16xf32>, %arg1: vector<32xbf16>, %arg2: vector<32xbf16>
+    %src: vector<16xf32>, %a: vector<32xbf16>, %b: vector<32xbf16>
   ) -> vector<16xf32>
 {
   // CHECK: call <16 x float> @llvm.x86.avx512bf16.dpbf16ps.512(
-  %0 = "x86vector.avx512.intr.dpbf16ps.512"(%arg0, %arg1, %arg2)
+  %0 = "x86vector.avx512.intr.dpbf16ps.512"(%src, %a, %b)
     : (vector<16xf32>, vector<32xbf16>, vector<32xbf16>) -> vector<16xf32>
   llvm.return %0 : vector<16xf32>
 }
 
+// CHECK-LABEL: define <8 x bfloat> @LLVM_x86_avx512bf16_cvtneps2bf16_256
+llvm.func @LLVM_x86_avx512bf16_cvtneps2bf16_256(
+  %a: vector<8xf32>) -> vector<8xbf16>
+{
+  // CHECK: call <8 x bfloat> @llvm.x86.avx512bf16.cvtneps2bf16.256(
+  %0 = "x86vector.avx512.intr.cvtneps2bf16.256"(%a)
+    : (vector<8xf32>) -> vector<8xbf16>
+  llvm.return %0 : vector<8xbf16>
+}
+
+// CHECK-LABEL: define <16 x bfloat> @LLVM_x86_avx512bf16_cvtneps2bf16_512
+llvm.func @LLVM_x86_avx512bf16_cvtneps2bf16_512(
+  %a: vector<16xf32>) -> vector<16xbf16>
+{
+  // CHECK: call <16 x bfloat> @llvm.x86.avx512bf16.cvtneps2bf16.512(
+  %0 = "x86vector.avx512.intr.cvtneps2bf16.512"(%a)
+    : (vector<16xf32>) -> vector<16xbf16>
+  llvm.return %0 : vector<16xbf16>
+}
+
 // CHECK-LABEL: define <8 x float> @LLVM_x86_avx_rsqrt_ps_256
 llvm.func @LLVM_x86_avx_rsqrt_ps_256(%a: vector <8xf32>) -> vector<8xf32>
 {
@@ -103,11 +123,11 @@ llvm.func @LLVM_x86_avx_rsqrt_ps_256(%a: vector <8xf32>) -> vector<8xf32>
 
 // CHECK-LABEL: define <8 x float> @LLVM_x86_avx_dp_ps_256
 llvm.func @LLVM_x86_avx_dp_ps_256(
-    %arg0: vector<8xf32>, %arg1: vector<8xf32>
+    %a: vector<8xf32>, %b: vector<8xf32>
   ) -> vector<8xf32>
 {
   // CHECK: call <8 x float> @llvm.x86.avx.dp.ps.256(
-  %0 = llvm.mlir.constant(-1 : i8) : i8
-  %1 = "x86vector.avx.intr.dp.ps.256"(%arg0, %arg1, %0) : (vector<8xf32>, vector<8xf32>, i8) -> vector<8xf32>
+  %c = llvm.mlir.constant(-1 : i8) : i8
+  %1 = "x86vector.avx.intr.dp.ps.256"(%a, %b, %c) : (vector<8xf32>, vector<8xf32>, i8) -> vector<8xf32>
   llvm.return %1 : vector<8xf32>
 }

@llvmbot
Copy link
Member

llvmbot commented Feb 4, 2025

@llvm/pr-subscribers-mlir-llvm

Author: Adam Siemieniuk (adam-smnk)

Changes

Adds AVX512 bf16 conversion from packed f32 to bf16 elements.

Tests are slightly refactored to better follow file's convention.


Full diff: https://github.com/llvm/llvm-project/pull/125685.diff

6 Files Affected:

  • (modified) mlir/include/mlir/Dialect/X86Vector/X86Vector.td (+40)
  • (modified) mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp (+40-2)
  • (added) mlir/test/Dialect/X86Vector/cvt-packed-f32-to-bf16.mlir (+24)
  • (modified) mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir (+18)
  • (modified) mlir/test/Dialect/X86Vector/roundtrip.mlir (+20)
  • (modified) mlir/test/Target/LLVMIR/x86vector.mlir (+29-9)
diff --git a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
index 16181d7e760db5f..566013e73f4b890 100644
--- a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
+++ b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
@@ -341,6 +341,46 @@ def DotBF16Ps512IntrOp : AVX512_IntrOp<"dpbf16ps.512", 1, [Pure,
   let results = (outs VectorOfLengthAndType<[16], [F32]>:$res);
 }
 
+//----------------------------------------------------------------------------//
+// Convert packed F32 to packed BF16
+//----------------------------------------------------------------------------//
+
+def CvtPackedF32ToBF16Op : AVX512_Op<"cvt.packed.f32_to_bf16", [Pure,
+  AllElementCountsMatch<["a", "dst"]>]> {
+  let summary = "Convert packed F32 to packed BF16 Data.";
+  let description = [{
+    The `convert_f32_to_bf16` op is an AVX512-BF16 specific op that can lower
+    to the proper LLVMAVX512BF16 operation `llvm.cvtneps2bf16` depending on
+    the width of MLIR vectors it is applied to.
+
+    #### From the Intel Intrinsics Guide:
+
+    Convert packed single-precision (32-bit) floating-point elements in `a` to
+    packed BF16 (16-bit) floating-point elements, and store the results in `dst`.
+
+    Example:
+    ```mlir
+    %dst = x86vector.avx512.cvt.packed.f32_to_bf16 %a : vector<8xf32> -> vector<8xbf16>
+    ```
+  }];
+  let arguments = (ins VectorOfLengthAndType<[8, 16], [F32]>:$a);
+  let results = (outs VectorOfLengthAndType<[8, 16], [BF16]>:$dst);
+  let assemblyFormat =
+    "$a attr-dict `:` type($a) `->` type($dst)";
+}
+
+def CvtNeF32ToBF16Ps256IntrOp : AVX512_IntrOp<"cvtneps2bf16.256", 1, [Pure],
+    /*extension=*/"bf16"> {
+  let arguments = (ins VectorOfLengthAndType<[8], [F32]>:$a);
+  let results = (outs VectorOfLengthAndType<[8], [BF16]>:$res);
+}
+
+def CvtNeF32ToBF16Ps512IntrOp : AVX512_IntrOp<"cvtneps2bf16.512", 1, [Pure],
+    /*extension=*/"bf16"> {
+  let arguments = (ins VectorOfLengthAndType<[16], [F32]>:$a);
+  let results = (outs VectorOfLengthAndType<[16], [BF16]>:$res);
+}
+
 //===----------------------------------------------------------------------===//
 // AVX op definitions
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp
index 260ac9ce589a38f..f1fbb39b97fc498 100644
--- a/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp
@@ -131,6 +131,39 @@ struct DotBF16OpConversion : public ConvertOpToLLVMPattern<DotBF16Op> {
   }
 };
 
+struct CvtPackedF32ToBF16Conversion
+    : public ConvertOpToLLVMPattern<CvtPackedF32ToBF16Op> {
+  using ConvertOpToLLVMPattern<CvtPackedF32ToBF16Op>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(CvtPackedF32ToBF16Op op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto typeA = dyn_cast<VectorType>(op.getA().getType());
+    unsigned elemBitWidth = typeA.getElementTypeBitWidth();
+    unsigned opBitWidth = typeA.getShape()[0] * elemBitWidth;
+
+    auto opType = op.getDst().getType();
+    auto opA = op.getA();
+
+    switch (opBitWidth) {
+    case 256: {
+      rewriter.replaceOpWithNewOp<CvtNeF32ToBF16Ps256IntrOp>(op, opType, opA);
+      break;
+    }
+    case 512: {
+      rewriter.replaceOpWithNewOp<CvtNeF32ToBF16Ps512IntrOp>(op, opType, opA);
+      break;
+    }
+    default: {
+      return rewriter.notifyMatchFailure(
+          op, "unsupported AVX512-BF16 packed f32 to bf16 variant");
+    }
+    }
+
+    return success();
+  }
+};
+
 struct RsqrtOpConversion : public ConvertOpToLLVMPattern<RsqrtOp> {
   using ConvertOpToLLVMPattern<RsqrtOp>::ConvertOpToLLVMPattern;
 
@@ -202,8 +235,10 @@ using Registry = RegistryImpl<
 void mlir::populateX86VectorLegalizeForLLVMExportPatterns(
     const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
   Registry::registerPatterns(converter, patterns);
-  patterns.add<MaskCompressOpConversion, DotBF16OpConversion, RsqrtOpConversion,
-               DotOpConversion>(converter);
+  patterns
+      .add<MaskCompressOpConversion, DotBF16OpConversion,
+           CvtPackedF32ToBF16Conversion, RsqrtOpConversion, DotOpConversion>(
+          converter);
 }
 
 void mlir::configureX86VectorLegalizeForExportTarget(
@@ -215,6 +250,9 @@ void mlir::configureX86VectorLegalizeForExportTarget(
   target.addLegalOp<DotBF16Ps256IntrOp>();
   target.addLegalOp<DotBF16Ps512IntrOp>();
   target.addIllegalOp<DotBF16Op>();
+  target.addLegalOp<CvtNeF32ToBF16Ps256IntrOp>();
+  target.addLegalOp<CvtNeF32ToBF16Ps512IntrOp>();
+  target.addIllegalOp<CvtPackedF32ToBF16Op>();
   target.addLegalOp<RsqrtIntrOp>();
   target.addIllegalOp<RsqrtOp>();
   target.addLegalOp<DotIntrOp>();
diff --git a/mlir/test/Dialect/X86Vector/cvt-packed-f32-to-bf16.mlir b/mlir/test/Dialect/X86Vector/cvt-packed-f32-to-bf16.mlir
new file mode 100644
index 000000000000000..c97c52f01c3b033
--- /dev/null
+++ b/mlir/test/Dialect/X86Vector/cvt-packed-f32-to-bf16.mlir
@@ -0,0 +1,24 @@
+// REQUIRES: target=x86{{.*}}
+
+// RUN: mlir-opt %s \
+// RUN:   -convert-vector-to-llvm="enable-x86vector" -convert-to-llvm \
+// RUN:   -reconcile-unrealized-casts | \
+// RUN: mlir-translate --mlir-to-llvmir | \
+// RUN: llc -mcpu=sapphirerapids | \
+// RUN: FileCheck %s
+
+func.func @avx512bf16_cvt_packed_f32_to_bf16_256(
+    %a: vector<8xf32>) -> vector<8xbf16> {
+  %0 = x86vector.avx512.cvt.packed.f32_to_bf16 %a : vector<8xf32> -> vector<8xbf16>
+  return %0 : vector<8xbf16>
+}
+// CHECK-LABEL: avx512bf16_cvt_packed_f32_to_bf16_256:
+// CHECK: vcvtneps2bf16{{.*}}%xmm
+
+func.func @avx512bf16_cvt_packed_f32_to_bf16_512(
+    %a: vector<16xf32>) -> vector<16xbf16> {
+  %0 = x86vector.avx512.cvt.packed.f32_to_bf16 %a : vector<16xf32> -> vector<16xbf16>
+  return %0 : vector<16xbf16>
+}
+// CHECK-LABEL: avx512bf16_cvt_packed_f32_to_bf16_512:
+// CHECK: vcvtneps2bf16{{.*}}%ymm
diff --git a/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir b/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir
index ed9177eaec9ce4a..59be7dd75b3b0b8 100644
--- a/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir
+++ b/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir
@@ -70,6 +70,24 @@ func.func @avx512bf16_dot_512(%src: vector<16xf32>, %a: vector<32xbf16>,
   return %0 : vector<16xf32>
 }
 
+// CHECK-LABEL: func @avx512bf16_cvt_packed_f32_to_bf16_256
+func.func @avx512bf16_cvt_packed_f32_to_bf16_256(
+  %a: vector<8xf32>) -> (vector<8xbf16>)
+{
+  // CHECK: x86vector.avx512.intr.cvtneps2bf16.256
+  %0 = x86vector.avx512.cvt.packed.f32_to_bf16 %a : vector<8xf32> -> vector<8xbf16>
+  return %0 : vector<8xbf16>
+}
+
+// CHECK-LABEL: func @avx512bf16_cvt_packed_f32_to_bf16_512
+func.func @avx512bf16_cvt_packed_f32_to_bf16_512(
+  %a: vector<16xf32>) -> (vector<16xbf16>)
+{
+  // CHECK: x86vector.avx512.intr.cvtneps2bf16.512
+  %0 = x86vector.avx512.cvt.packed.f32_to_bf16 %a : vector<16xf32> -> vector<16xbf16>
+  return %0 : vector<16xbf16>
+}
+
 // CHECK-LABEL: func @avx_rsqrt
 func.func @avx_rsqrt(%a: vector<8xf32>) -> (vector<8xf32>)
 {
diff --git a/mlir/test/Dialect/X86Vector/roundtrip.mlir b/mlir/test/Dialect/X86Vector/roundtrip.mlir
index cf74a7ee602558f..0d00448c63da889 100644
--- a/mlir/test/Dialect/X86Vector/roundtrip.mlir
+++ b/mlir/test/Dialect/X86Vector/roundtrip.mlir
@@ -74,6 +74,26 @@ func.func @avx512bf16_dot_512(%src: vector<16xf32>, %a: vector<32xbf16>,
   return %0 : vector<16xf32>
 }
 
+// CHECK-LABEL: func @avx512bf16_cvt_packed_f32_to_bf16_256
+func.func @avx512bf16_cvt_packed_f32_to_bf16_256(
+  %a: vector<8xf32>) -> (vector<8xbf16>)
+{
+  // CHECK: x86vector.avx512.cvt.packed.f32_to_bf16 {{.*}} :
+  // CHECK-SAME: vector<8xf32> -> vector<8xbf16>
+  %0 = x86vector.avx512.cvt.packed.f32_to_bf16 %a : vector<8xf32> -> vector<8xbf16>
+  return %0 : vector<8xbf16>
+}
+
+// CHECK-LABEL: func @avx512bf16_cvt_packed_f32_to_bf16_512
+func.func @avx512bf16_cvt_packed_f32_to_bf16_512(
+  %a: vector<16xf32>) -> (vector<16xbf16>)
+{
+  // CHECK: x86vector.avx512.cvt.packed.f32_to_bf16 {{.*}} :
+  // CHECK-SAME: vector<16xf32> -> vector<16xbf16>
+  %0 = x86vector.avx512.cvt.packed.f32_to_bf16 %a : vector<16xf32> -> vector<16xbf16>
+  return %0 : vector<16xbf16>
+}
+
 // CHECK-LABEL: func @avx_rsqrt
 func.func @avx_rsqrt(%a: vector<8xf32>) -> (vector<8xf32>)
 {
diff --git a/mlir/test/Target/LLVMIR/x86vector.mlir b/mlir/test/Target/LLVMIR/x86vector.mlir
index 1df03f10c93214a..db1c10cd5cd37a2 100644
--- a/mlir/test/Target/LLVMIR/x86vector.mlir
+++ b/mlir/test/Target/LLVMIR/x86vector.mlir
@@ -62,37 +62,57 @@ llvm.func @LLVM_x86_vp2intersect_q_512(%a: vector<8xi64>, %b: vector<8xi64>)
 
 // CHECK-LABEL: define <4 x float> @LLVM_x86_avx512bf16_dpbf16ps_128
 llvm.func @LLVM_x86_avx512bf16_dpbf16ps_128(
-    %arg0: vector<4xf32>, %arg1: vector<8xbf16>, %arg2: vector<8xbf16>
+    %src: vector<4xf32>, %a: vector<8xbf16>, %b: vector<8xbf16>
   ) -> vector<4xf32>
 {
   // CHECK: call <4 x float> @llvm.x86.avx512bf16.dpbf16ps.128(
-  %0 = "x86vector.avx512.intr.dpbf16ps.128"(%arg0, %arg1, %arg2)
+  %0 = "x86vector.avx512.intr.dpbf16ps.128"(%src, %a, %b)
     : (vector<4xf32>, vector<8xbf16>, vector<8xbf16>) -> vector<4xf32>
   llvm.return %0 : vector<4xf32>
 }
 
 // CHECK-LABEL: define <8 x float> @LLVM_x86_avx512bf16_dpbf16ps_256
 llvm.func @LLVM_x86_avx512bf16_dpbf16ps_256(
-    %arg0: vector<8xf32>, %arg1: vector<16xbf16>, %arg2: vector<16xbf16>
+    %src: vector<8xf32>, %a: vector<16xbf16>, %b: vector<16xbf16>
   ) -> vector<8xf32>
 {
   // CHECK: call <8 x float> @llvm.x86.avx512bf16.dpbf16ps.256(
-  %0 = "x86vector.avx512.intr.dpbf16ps.256"(%arg0, %arg1, %arg2)
+  %0 = "x86vector.avx512.intr.dpbf16ps.256"(%src, %a, %b)
     : (vector<8xf32>, vector<16xbf16>, vector<16xbf16>) -> vector<8xf32>
   llvm.return %0 : vector<8xf32>
 }
 
 // CHECK-LABEL: define <16 x float> @LLVM_x86_avx512bf16_dpbf16ps_512
 llvm.func @LLVM_x86_avx512bf16_dpbf16ps_512(
-    %arg0: vector<16xf32>, %arg1: vector<32xbf16>, %arg2: vector<32xbf16>
+    %src: vector<16xf32>, %a: vector<32xbf16>, %b: vector<32xbf16>
   ) -> vector<16xf32>
 {
   // CHECK: call <16 x float> @llvm.x86.avx512bf16.dpbf16ps.512(
-  %0 = "x86vector.avx512.intr.dpbf16ps.512"(%arg0, %arg1, %arg2)
+  %0 = "x86vector.avx512.intr.dpbf16ps.512"(%src, %a, %b)
     : (vector<16xf32>, vector<32xbf16>, vector<32xbf16>) -> vector<16xf32>
   llvm.return %0 : vector<16xf32>
 }
 
+// CHECK-LABEL: define <8 x bfloat> @LLVM_x86_avx512bf16_cvtneps2bf16_256
+llvm.func @LLVM_x86_avx512bf16_cvtneps2bf16_256(
+  %a: vector<8xf32>) -> vector<8xbf16>
+{
+  // CHECK: call <8 x bfloat> @llvm.x86.avx512bf16.cvtneps2bf16.256(
+  %0 = "x86vector.avx512.intr.cvtneps2bf16.256"(%a)
+    : (vector<8xf32>) -> vector<8xbf16>
+  llvm.return %0 : vector<8xbf16>
+}
+
+// CHECK-LABEL: define <16 x bfloat> @LLVM_x86_avx512bf16_cvtneps2bf16_512
+llvm.func @LLVM_x86_avx512bf16_cvtneps2bf16_512(
+  %a: vector<16xf32>) -> vector<16xbf16>
+{
+  // CHECK: call <16 x bfloat> @llvm.x86.avx512bf16.cvtneps2bf16.512(
+  %0 = "x86vector.avx512.intr.cvtneps2bf16.512"(%a)
+    : (vector<16xf32>) -> vector<16xbf16>
+  llvm.return %0 : vector<16xbf16>
+}
+
 // CHECK-LABEL: define <8 x float> @LLVM_x86_avx_rsqrt_ps_256
 llvm.func @LLVM_x86_avx_rsqrt_ps_256(%a: vector <8xf32>) -> vector<8xf32>
 {
@@ -103,11 +123,11 @@ llvm.func @LLVM_x86_avx_rsqrt_ps_256(%a: vector <8xf32>) -> vector<8xf32>
 
 // CHECK-LABEL: define <8 x float> @LLVM_x86_avx_dp_ps_256
 llvm.func @LLVM_x86_avx_dp_ps_256(
-    %arg0: vector<8xf32>, %arg1: vector<8xf32>
+    %a: vector<8xf32>, %b: vector<8xf32>
   ) -> vector<8xf32>
 {
   // CHECK: call <8 x float> @llvm.x86.avx.dp.ps.256(
-  %0 = llvm.mlir.constant(-1 : i8) : i8
-  %1 = "x86vector.avx.intr.dp.ps.256"(%arg0, %arg1, %0) : (vector<8xf32>, vector<8xf32>, i8) -> vector<8xf32>
+  %c = llvm.mlir.constant(-1 : i8) : i8
+  %1 = "x86vector.avx.intr.dp.ps.256"(%a, %b, %c) : (vector<8xf32>, vector<8xf32>, i8) -> vector<8xf32>
   llvm.return %1 : vector<8xf32>
 }

Copy link
Member

@rengolin rengolin left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same as before. Not the best design, but following the existing pattern. This is helping us understand the design and will be the target of an RFC soon on how to simplify CPU dialect lowering without needing a full-blown VM dialect.

@adam-smnk adam-smnk merged commit 2b71df5 into llvm:main Feb 18, 2025
12 checks passed
wldfngrs pushed a commit to wldfngrs/llvm-project that referenced this pull request Feb 19, 2025
Adds AVX512 bf16 conversion from packed f32 to bf16 elements.

Tests are slightly refactored to better follow file's convention.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants